#include <time.h>

#include <cstdio>
#include <iostream>
#include <vector>

#include <mpi.h>
using namespace std;
// Edge length of the square matrices handled by this program.
#define SIZE 4

/**
 * Both operand matrices of the product A * B, stored back to back in a single
 * object so that the pair can be described by one MPI derived datatype and
 * shipped in one message.
 */
struct AandB
{
    double A[SIZE][SIZE]; // left operand
    double B[SIZE][SIZE]; // right operand
};
int main(void)
{

    const int m = SIZE, n = SIZE, k = SIZE;
    int comm_sz; //进程的数量
    int my_rank; //进程的编号
    MPI_Init(NULL, NULL);
    MPI_Comm_size(MPI_COMM_WORLD, &comm_sz);
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
    const int row_range = m / comm_sz;
    int C_size = row_range * k;
    double(*B)[k] = new double[n][k];
    double(*A)[n] = new double[m][n];
    double(*C)[k] = new double[m][k];
    double(*buf_C)[k] = new double[row_range][k];

    AandB aandb;

    if (my_rank == 0)
    {
        for (int i = 0; i < m; i++)
        {
            for (int j = 0; j < n; j++)
            {
                A[i][j] = i * m + j + 1;
                aandb.A[i][j] = A[i][j];
            }
        }
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < k; j++)
            {
                B[i][j] = i * n + j + 1;
                aandb.B[i][j] = B[i][j];
            }
        }
    }
    int count = 2;
    int blocklengths[count] = {m * n, n * k};
    MPI_Aint offsets[count];
    MPI_Get_address(aandb.A, &offsets[0]);
    MPI_Get_address(aandb.B, &offsets[1]);
    offsets[1] = offsets[1] - offsets[0];
    offsets[0] = 0;
    MPI_Datatype types[count] = {MPI_DOUBLE, MPI_DOUBLE};
    MPI_Datatype aandbtype;
    MPI_Type_create_struct(count, blocklengths, offsets, types, &aandbtype);
    MPI_Type_commit(&aandbtype);

    double start, end;
    MPI_Barrier(MPI_COMM_WORLD);
    start = MPI_Wtime();

    MPI_Bcast(&aandb, 1, aandbtype, 0, MPI_COMM_WORLD);
    MPI_Barrier(MPI_COMM_WORLD);

    for (int i = my_rank*row_range; i < my_rank*row_range+row_range; i++)
    {
        for (int j = 0; j < k; j++)
        {
            buf_C[i-my_rank*row_range][j] = 0;
            for (int l = 0; l < n; l++)
            {
                buf_C[i-my_rank*row_range][j] += aandb.B[l][j] * aandb.A[i][l];
            }
        }
    }
    MPI_Gather(&buf_C[0][0], C_size, MPI_DOUBLE, &C[0][0], C_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);

    MPI_Barrier(MPI_COMM_WORLD);
    end = MPI_Wtime();
    if (my_rank == 0)
    {
        if (SIZE <= 8)
        {
            for (int i = 0; i < m; i++)
            {
                for (int j = 0; j < k; j++)
                {
                    printf("%lf ", C[i][j]);
                }
                putchar('\n');
            }
        }
        printf("%lf s\n", (double)(end - start));
    }

    MPI_Finalize();
    delete[] A;
    delete[] B;
    delete[] C;
    delete[] buf_C;
    return 0;
}